from partnet.ObjFile import ObjFile

import numpy as np
import os
from matplotlib import pyplot as plt
import imageio
from PIL import Image
from scipy.interpolate import griddata
import json
import argparse
import cv2

def _build_arg_parser():
    """Create the command-line parser for this conversion script.

    Options:
        --root_dir: directory that holds the per-object folders (default "partnet/data").
        --exp: name of the object folder under ``root_dir`` (required).
        --num_views: integer option, default 24.
        --img_size: integer option, default 400.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--root_dir", type=str, default="partnet/data")
    cli.add_argument("--exp", type=str, required=True)
    cli.add_argument("--num_views", type=int, default=24)
    cli.add_argument("--img_size", type=int, default=400)
    return cli


# Parsed at import time; the script entry point reads ``args``.
parser = _build_arg_parser()
args = parser.parse_args()

def _normalize_depth(depth):
    """Min-max scale the finite entries of *depth* to [0, 1]; everything else becomes 0.

    Non-finite entries (``inf``/``nan``) are excluded from the min/max and mapped to 0.
    A map with no finite values, or whose finite values are all equal, yields all zeros
    instead of the NaNs a plain ``(d - min) / (max - min)`` would produce.
    """
    depth = np.asarray(depth, dtype=np.float64)
    finite = np.isfinite(depth)
    normalized = np.zeros_like(depth)
    if not finite.any():
        return normalized
    dmin = depth[finite].min()
    drange = depth[finite].max() - dmin
    if drange > 0:
        normalized[finite] = (depth[finite] - dmin) / drange
    return normalized


def depth2img(depth):
    """Colorize a 2-D depth map with OpenCV's TURBO colormap.

    The depth is first scaled to [0, 1] over its finite values (see ``_normalize_depth``),
    quantized to uint8, and passed to ``cv2.applyColorMap``.

    Returns:
        The color image produced by ``cv2.applyColorMap``.
    """
    normalized = _normalize_depth(depth)
    return cv2.applyColorMap((normalized * 255).astype(np.uint8),
                             cv2.COLORMAP_TURBO)

def get_leaf_obj_files(json_path):
    """Return the ``.obj`` file names referenced by the hierarchy stored in *json_path*.

    The JSON file holds a list whose first element is the root node. Nodes are
    visited depth-first in document order. A node that has an ``"objs"`` list
    contributes ``"<entry>.obj"`` for each entry, and its children are NOT visited.
    Any other node has its ``"children"`` (if present) visited in order.
    """
    with open(json_path) as fh:
        root = json.load(fh)[0]

    collected = []
    pending = [root]
    while pending:
        node = pending.pop()
        if "objs" in node:
            collected += [f"{name}.obj" for name in node["objs"]]
        elif "children" in node:
            # Push in reverse so the first child is popped (visited) first.
            pending.extend(reversed(node["children"]))
    return collected


def sample_spherical_viewpoints(num_views, radius=3.0, mode='train', val_offset=30.0, seed=None):
    """Place ``num_views`` cameras at distance ``radius`` from the origin, each facing it.

    Elevation steps evenly through the open interval (0, 180) degrees and azimuth
    steps evenly through [0, 360). The two advance together: view k uses the k-th
    elevation and the k-th azimuth. With ``mode == 'val'`` both angles are shifted
    by ``val_offset`` degrees and wrapped (elevation mod 180, azimuth mod 360). Any
    other mode uses the unshifted angles. ``seed``, when given, seeds NumPy's global
    RNG; nothing in this function draws random numbers.

    Returns:
        (poses, elevations, azimuths):
            ``poses`` stacks one 3x4 ``[R | t]`` matrix per view. The columns of R
            are the camera x, y, z axes in world coordinates, and t is the camera
            position.
            ``elevations`` and ``azimuths`` are returned in radians.

    NOTE(review): the camera z axis points from the camera toward the origin. Confirm
    that this matches what ``ObjFile.Plot`` and the consumers of the transforms
    file expect.
    """
    if seed is not None:
        np.random.seed(seed)

    elev_deg = np.linspace(0, 180, num_views + 2)[1:-1]
    azim_deg = np.linspace(0, 360, num_views + 1)[:-1]
    if mode == 'val':
        elev_deg = np.mod(elev_deg + val_offset, 180)
        azim_deg = np.mod(azim_deg + val_offset, 360)

    elev_rad = np.radians(elev_deg)
    azim_rad = np.radians(azim_deg)

    world_up = [0, 0, 1]
    poses = []
    for theta, phi in zip(elev_rad, azim_rad):
        cam_pos = np.array([
            radius * np.cos(theta) * np.cos(phi),
            radius * np.cos(theta) * np.sin(phi),
            radius * np.sin(theta),
        ])
        forward = -cam_pos / np.linalg.norm(cam_pos)
        side = np.cross(world_up, forward)
        side_len = np.linalg.norm(side)
        # Exactly parallel to world up (camera on the z axis): fall back to +x.
        side = side / side_len if side_len > 0 else np.array([1, 0, 0])
        up = np.cross(forward, side)
        rotation = np.column_stack((side, up, forward))
        poses.append(np.hstack((rotation, cam_pos[:, None])))

    return np.array(poses), elev_rad, azim_rad


def _combine_objs(objs):
    """Merge several ObjFile meshes into one ObjFile with a single shared node array.

    Each mesh's ``nodes`` starts with a dummy node at index 0, which is dropped per
    mesh. Face indices are shifted by the running node count, and one dummy node is
    prepended to the merged array. (Presumably faces use 1-based node indices; the
    offset arithmetic relies on that.)
    """
    combined_nodes = np.concatenate([obj.nodes[1:] for obj in objs], axis=0)
    combined_faces = []
    offset = 0
    for obj in objs:
        combined_faces.extend([[f + offset for f in face] for face in obj.faces])
        offset += len(obj.nodes) - 1
    combined_obj = ObjFile()
    combined_obj.nodes = np.concatenate([[[0.0, 0.0, 0.0]], combined_nodes], axis=0)
    combined_obj.faces = combined_faces
    return combined_obj


def _load_image_then_remove(path, mode):
    """Read the image at *path*, convert it to PIL *mode*, delete the file, return an array."""
    with Image.open(path) as im:
        arr = np.array(im.convert(mode))
    os.remove(path)
    return arr


def render_views(root_dir, json_path, poses, elevations, azimuths, img_size=400):
    """Render RGB images, per-part masks, and depth maps of one object for each pose.

    Args:
        root_dir: Object folder. Meshes are read from ``<root_dir>/objs``, and
            temporary renders go to ``<root_dir>/tmp``, which is removed afterwards
            if it ends up empty.
        json_path: Hierarchy JSON whose leaf ``.obj`` files are rendered
            (see ``get_leaf_obj_files``).
        poses: Iterable of camera poses passed to ``ObjFile.Plot``/``PlotDepth``.
        elevations, azimuths: Only zipped with ``poses``, so the shortest of the
            three bounds the number of views. Their values are not used.
        img_size: Square render size in pixels. Also sets the intrinsics.

    Returns:
        images: list of RGB uint8 arrays, one per view.
        seg_maps: dict ``obj_file -> list`` of uint8 masks (0/255), one per view.
        depth_maps: list of arrays loaded from ``PlotDepth`` output, one per view.
        K: 3x3 float32 pinhole intrinsics (60 degree FOV, centered principal point).

    Raises:
        ValueError: if the hierarchy lists no leaf ``.obj`` files.
    """
    print(f"Rendering views for {root_dir}")
    obj_dir = os.path.join(root_dir, "objs")

    obj_files = get_leaf_obj_files(json_path)
    if not obj_files:
        raise ValueError(f"No leaf .obj files found in {json_path}")
    objs = [ObjFile(obj_file=os.path.join(obj_dir, obj_file)) for obj_file in obj_files]
    print(f"Loaded {len(objs)} objects")

    combined_obj = _combine_objs(objs)
    nmin, nmax = combined_obj.MinMaxNodes()
    # Every render (whole shape and each part) uses the whole shape's bounds,
    # presumably so that part renders are framed identically to the full render.
    view_kwargs = dict(
        width=img_size,
        height=img_size,
        xlim=(nmin[0], nmax[0]),
        ylim=(nmin[1], nmax[1]),
        zlim=(nmin[2], nmax[2]),
    )

    fov = 60  # degrees; the image is square, so fx == fy
    fx = fy = img_size / (2 * np.tan(np.radians(fov / 2)))
    K = np.array([[fx, 0, img_size / 2],
                  [0, fy, img_size / 2],
                  [0, 0, 1]], dtype=np.float32)

    images = []
    seg_maps = {obj_file: [] for obj_file in obj_files}
    depth_maps = []

    tmp_dir = os.path.join(root_dir, "tmp")
    os.makedirs(tmp_dir, exist_ok=True)
    for i, (_elev, _azim, pose) in enumerate(zip(elevations, azimuths, poses)):
        rgb_path = os.path.join(tmp_dir, f"temp_rgb_{i:03d}.png")
        combined_obj.Plot(output_file=rgb_path, pose=pose, **view_kwargs)
        rgb = _load_image_then_remove(rgb_path, "RGB")
        images.append(rgb)
        # Foreground = any pixel of the full render that is not pure black.
        mask = np.any(rgb > 0, axis=2)

        # Render each part alone and binarize it into a 0/255 mask.
        for j, (obj_file, obj) in enumerate(zip(obj_files, objs)):
            seg_path = os.path.join(tmp_dir, f"temp_seg_{i:03d}_{j}.png")
            obj.Plot(output_file=seg_path, pose=pose, **view_kwargs)
            seg = _load_image_then_remove(seg_path, "L")
            seg_maps[obj_file].append((seg > 0).astype(np.uint8) * 255)

        # Keep the depth temporary with the other temporaries, not in root_dir.
        depth_path = os.path.join(tmp_dir, f"temp_depth_{i:03d}.npy")
        combined_obj.PlotDepth(output_file=depth_path, pose=pose, K=K, mask=mask, **view_kwargs)
        depth_maps.append(np.load(depth_path))
        os.remove(depth_path)

    # Best-effort cleanup: rmdir only succeeds when the scratch dir is empty,
    # so pre-existing content is never touched.
    try:
        os.rmdir(tmp_dir)
    except OSError:
        pass

    return images, seg_maps, depth_maps, K


def generate_hierarchy(root_dir, json_path, img_size, num_views=24):
    """Build per-view "inside"/"same" relation matrices over the leaf parts of a hierarchy.

    Args:
        root_dir: Unused; kept for call-site compatibility.
        json_path: JSON file whose first list element is the root node. Each node has
            an ``"id"``, optionally an ``"objs"`` list and optionally ``"children"``.
        img_size: Unused; kept for call-site compatibility.
        num_views: Number of views to replicate the matrices for.

    Every node carrying ``"objs"`` counts as a leaf. Leaves are indexed in ascending
    ``id`` order. The relations do not depend on the viewpoint, so a single matrix is
    computed and repeated ``num_views`` times (each view gets its own copy).

    Returns:
        inside: bool array of shape (num_views, L, L). ``inside[v, a, b]`` is True
            when leaf b is a strict ancestor of leaf a.
        same: bool array of shape (num_views, L, L), the identity in every view.
        obj_to_id: dict mapping ``"<obj>.obj"`` to the id of the leaf listing it.
            If an obj is listed by several leaves, the one visited last wins.
    """
    with open(json_path, 'r') as f:
        hierarchy = json.load(f)[0]

    parent_of = {}   # node id -> parent node id (None for the root)
    leaf_objs = {}   # leaf node id -> its "objs" list, in depth-first visiting order

    def traverse(node, parent_id=None):
        node_id = node["id"]
        parent_of[node_id] = parent_id
        if "objs" in node:
            leaf_objs[node_id] = node["objs"]
        for child in node.get("children", ()):
            traverse(child, node_id)

    traverse(hierarchy)

    obj_to_id = {}
    for leaf_id, objs in leaf_objs.items():
        for obj in objs:
            obj_to_id[f"{obj}.obj"] = leaf_id

    leaf_ids = sorted(leaf_objs)
    num_leaves = len(leaf_ids)
    leaf_index = {leaf_id: idx for idx, leaf_id in enumerate(leaf_ids)}

    # View-independent relations: walk each leaf's ancestor chain once.
    inside_single = np.zeros((num_leaves, num_leaves), dtype=bool)
    for leaf_id in leaf_ids:
        ancestor = parent_of[leaf_id]
        while ancestor is not None:
            if ancestor in leaf_index:
                inside_single[leaf_index[leaf_id], leaf_index[ancestor]] = True
            ancestor = parent_of[ancestor]
    same_single = np.eye(num_leaves, dtype=bool)

    # np.repeat returns independent, writable copies (unlike np.broadcast_to).
    inside = np.repeat(inside_single[np.newaxis], num_views, axis=0)
    same = np.repeat(same_single[np.newaxis], num_views, axis=0)
    return inside, same, obj_to_id

def _merge_leaf_masks(seg_maps, obj_to_id):
    """Union per-obj masks into per-leaf masks.

    Several ``.obj`` files can belong to the same leaf id. Their masks for each view
    are combined with an element-wise maximum, which is a union for 0/255 masks.
    Input arrays are not modified. Leaf ids appear in the order they are first seen
    in ``seg_maps``.
    """
    merged = {}
    for obj_file, masks in seg_maps.items():
        leaf_id = obj_to_id[obj_file]
        masks = [np.asarray(m) for m in masks]
        if leaf_id in merged:
            merged[leaf_id] = [np.maximum(a, b) for a, b in zip(merged[leaf_id], masks)]
        else:
            merged[leaf_id] = masks
    return merged


def save_dataset(root_dir, images, seg_maps, depth_maps, poses, K, inside, same, obj_to_id, img_size=400):
    """Write one rendered split to ``root_dir``.

    Layout written:
        train/r_<i>.png                          RGB images
        train_seg/r_<i>/<leaf_id>.png            per-leaf masks (union of the leaf's objs)
        train_seg_hierarchy/r_<i>_inside.npy     ``inside[i]``
        train_seg_hierarchy/r_<i>_same.npy       ``same[i]``
        train_depth/r_<i>_depth.npy              depth maps
        transforms_train.json                    camera_angle_x + per-frame poses

    Args:
        seg_maps: dict ``obj_file -> list`` of per-view masks. 0/255 values are
            assumed, since they are saved with a fixed 0..255 gray scale.
        inside, same: per-view relation matrices. They need at least
            ``len(images)`` entries.
        obj_to_id: maps each key of ``seg_maps`` to its leaf id.
        img_size: image width used for ``camera_angle_x``. It must match the size
            the views were rendered at (the one ``K`` was built for).

    NOTE(review): frames store the 3x4 poses as given. Some NeRF-style loaders
    expect 4x4 matrices, so confirm this with the consumer.
    """
    for sub in ("train", "train_seg", "train_seg_hierarchy", "train_depth"):
        os.makedirs(os.path.join(root_dir, sub), exist_ok=True)

    for i, img in enumerate(images):
        imageio.imwrite(os.path.join(root_dir, "train", f"r_{i}.png"), img)

    # Merge first: writing per obj file would let later objs of the same leaf
    # overwrite earlier ones at the same <leaf_id>.png path.
    for leaf_id, masks in _merge_leaf_masks(seg_maps, obj_to_id).items():
        for i, mask in enumerate(masks):
            seg_dir = os.path.join(root_dir, "train_seg", f"r_{i}")
            os.makedirs(seg_dir, exist_ok=True)
            # Fixed vmin/vmax: without them imsave auto-normalizes, so an
            # all-255 mask (part covering every pixel) would be saved black.
            plt.imsave(os.path.join(seg_dir, f"{leaf_id}.png"), mask,
                       cmap="gray", vmin=0, vmax=255)

    hierarchy_dir = os.path.join(root_dir, "train_seg_hierarchy")
    for i in range(len(images)):
        np.save(os.path.join(hierarchy_dir, f"r_{i}_inside.npy"), inside[i])
        np.save(os.path.join(hierarchy_dir, f"r_{i}_same.npy"), same[i])

    for i, depth in enumerate(depth_maps):
        np.save(os.path.join(root_dir, "train_depth", f"r_{i}_depth.npy"), depth)

    transforms = {
        "camera_angle_x": float(2 * np.arctan(img_size / (2 * K[0, 0]))),
        "frames": [
            {"file_path": f"train/r_{i}", "transform_matrix": pose.tolist()}
            for i, pose in enumerate(poses)
        ],
    }
    with open(os.path.join(root_dir, "transforms_train.json"), 'w') as f:
        json.dump(transforms, f, indent=2)


def main():
    """Render the requested splits of one PartNet object and save each as a dataset.

    Reads ``<root_dir>/<exp>/result.json`` and the meshes under ``<root_dir>/<exp>``.
    The train split is written to ``<root_dir>/<exp>_converted`` and the val split
    to ``<root_dir>/<exp>_val_converted``. ``--img_size`` is used both for rendering
    and for the saved intrinsics, and ``--num_views`` sets the train view count.
    """
    # Splits to generate; add 'train' to also produce the training split.
    modes = ['val']

    data_dir = os.path.join(args.root_dir, args.exp)
    json_path = os.path.join(data_dir, "result.json")

    def convert_split(num_views, save_suffix, **sampling):
        poses, elevs, azims = sample_spherical_viewpoints(num_views, **sampling)
        # The relation matrices are replicated per view, so size them to this split
        # (a fixed 24 would overflow for larger splits).
        inside, same, obj_to_id = generate_hierarchy(
            data_dir, json_path, img_size=args.img_size, num_views=len(poses))
        # Render at the same size that save_dataset uses for camera_angle_x.
        images, seg_maps, depth_maps, K = render_views(
            data_dir, json_path, poses, elevs, azims, img_size=args.img_size)
        save_dir = os.path.join(args.root_dir, args.exp + save_suffix)
        save_dataset(save_dir, images, seg_maps, depth_maps, poses, K,
                     inside, same, obj_to_id, img_size=args.img_size)

    if 'train' in modes:
        convert_split(args.num_views, "_converted", mode='train')

    if 'val' in modes:
        convert_split(8, "_val_converted", mode='val', val_offset=30.0)


if __name__ == '__main__':
    main()